
Conversation

@kiritorl (Contributor) commented Jan 20, 2026

Summary

This PR modifies the NPU test reference for KLDivLoss. Since the native NPU KLDivLoss operator does not support gradients w.r.t. the target (#1021), it caused failures in test_jsd.py (where input and target are swapped when beta != 0).

To resolve this, I replaced the native operator usage with a custom implementation using basic math operations. This allows correct gradient computation for the target and aligns the x1.grad results with the Triton kernel implementation.
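
For context, a KLDivLoss built from basic differentiable operations can look roughly like the sketch below (the class name, eps guard, and reduction handling are illustrative assumptions, not necessarily the exact code in this PR); because every term flows through autograd, gradients also reach the target argument:

```
import torch


class CustomKLDivLoss(torch.nn.Module):
    # Illustrative sketch: KL divergence assembled from basic ops so that
    # autograd also produces gradients w.r.t. `target`, unlike the native NPU
    # KLDivLoss operator. Follows torch.nn.KLDivLoss semantics: `input` holds
    # log-probabilities; `target` holds probabilities, or log-probabilities
    # when log_target=True.
    def __init__(self, reduction="batchmean", log_target=False, eps=1e-10):
        super().__init__()
        self.reduction = reduction
        self.log_target = log_target
        self.eps = eps  # assumed small constant to avoid log(0)

    def forward(self, input, target):
        if self.log_target:
            pointwise = torch.exp(target) * (target - input)
        else:
            pointwise = target * (torch.log(target + self.eps) - input)
        if self.reduction == "batchmean":
            return pointwise.sum() / input.size(0)
        if self.reduction == "sum":
            return pointwise.sum()
        if self.reduction == "mean":
            return pointwise.mean()
        return pointwise  # reduction == "none"
```

Expressing the loss this way keeps the reference computation on the same autograd path as the Triton kernel's backward, which is what lets x1.grad match x2.grad in the test output shown below.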

Testing Done

I tested test_jsd and test_fused_linear_jsd with the following commands, and all cases passed:

pytest -v test/transformers/test_jsd.py
pytest -v test/transformers/test_fused_linear_jsd.py

Hardware Type: Ascend NPU 910B3

  • run make test to ensure correctness
  • run make checkstyle to ensure code style
  • run make test-convergence to ensure convergence

@kiritorl changed the title from "fix(npu): update the native KLDivLoss implementation for comparison." to "[NPU]: update the native KLDivLoss implementation for comparison. (eg.)test_jsd.py" on Jan 20, 2026
@kiritorl (Contributor, Author) commented Jan 20, 2026

Test results on NPU before:

Error in
test/transformers/test_jsd.py:160: in _test_correctness_once
assert_verbose_allclose(x1.grad, x2.grad, atol=atol, rtol=rtol)

tensor1 = tensor([[-3.6322e-08, -8.1956e-08, -4.1211e-08,  ..., -1.8999e-07,
         -4.9593e-08, -4.3772e-08],
        [-6.379...-5.6345e-08, -6.3796e-08,  ..., -1.6182e-08,
         -8.1956e-08, -1.2293e-07]], device='npu:0', dtype=torch.bfloat16)
tensor2 = tensor([[-1.0186e-08,  2.9686e-08,  1.0885e-08,  ...,  1.1525e-08,
          1.6182e-08,  6.3155e-09],
        [ 1.397... 1.7229e-08,  3.8883e-08,  ..., -4.5402e-09,
          3.2363e-08,  9.2550e-09]], device='npu:0', dtype=torch.bfloat16)

Test results on NPU after:

tensor1: tensor([[-1.0186e-08,  2.9686e-08,  1.0885e-08,  ...,  1.1525e-08,
          1.6182e-08,  6.3155e-09],
        [ 1.3970e-08,  5.2620e-08,  8.2888e-08,  ..., -5.8790e-09,
         -6.5775e-09,  4.7497e-08],
        [-1.1059e-08, -1.8859e-08, -1.6298e-08,  ..., -8.2655e-09,
          5.5297e-09,  9.8720e-08],
        ...,
        [-1.0012e-08,  1.8068e-07,  0.0000e+00,  ..., -1.2689e-08,
          1.7229e-08, -2.4214e-08],
        [-7.1304e-09,  1.2515e-08,  4.7963e-08,  ..., -1.4808e-07,
          2.2468e-08,  3.3324e-09],
        [-4.1444e-08,  1.7229e-08,  3.8883e-08,  ..., -4.5402e-09,
          3.2363e-08,  9.2550e-09]], device='npu:0', dtype=torch.bfloat16)
tensor2: tensor([[-1.0186e-08,  2.9686e-08,  1.0885e-08,  ...,  1.1525e-08,
          1.6182e-08,  6.3155e-09],
        [ 1.3970e-08,  5.2620e-08,  8.2888e-08,  ..., -5.8790e-09,
         -6.5775e-09,  4.7497e-08],
        [-1.1059e-08, -1.8859e-08, -1.6298e-08,  ..., -8.2655e-09,
          5.5297e-09,  9.8720e-08],
        ...,
        [-1.0012e-08,  1.8068e-07,  0.0000e+00,  ..., -1.2689e-08,
          1.7229e-08, -2.4214e-08],
        [-7.1304e-09,  1.2515e-08,  4.7963e-08,  ..., -1.4808e-07,
          2.2468e-08,  3.3324e-09],
        [-4.1444e-08,  1.7229e-08,  3.8883e-08,  ..., -4.5402e-09,
          3.2363e-08,  9.2550e-09]], device='npu:0', dtype=torch.bfloat16)
PASSED

@Tcc0403 (Collaborator) left a comment


Great catch; you can report the issue to the torch npu team.

Overall lgtm, just a nit change for documentation.

set_seed(42)


class CustomKLDivLoss(torch.nn.Module):

nit: Add a docstring to explain why we need a custom KLDivLoss.

Since it's an NPU-exclusive issue, I'd also include NPU in the name.

Suggested change
- class CustomKLDivLoss(torch.nn.Module):
+ class NPUKLDivLoss(torch.nn.Module):
+     """
+     A custom KLDivLoss for NPU.
+     On NPU devices, torch.nn.KLDivLoss does not compute gradients with respect to the target.
+     This leads to incorrect gradient computation when the target depends on the input,
+     such as in JSD or reverse KLDiv.
+     See https://github.com/linkedin/Liger-Kernel/issues/1021 for more details.
+     """

@kiritorl (Contributor, Author)

Thanks. I have reported the issue to the Torch NPU team and updated the documentation as requested.

@Tcc0403 (Collaborator) left a comment


Thank you!

@Tcc0403 (Collaborator) left a comment


Oops, my bad, there's a formatting issue. You can also install pre-commit hooks to fix it automatically; see:

3. **Install pre-commit hooks using [`prek`](https://prek.j178.dev/), a `pre-commit` alternative built in rust**
```
prek install
```
Run pre-commit check without committing (`-a` is equivalent to `--all-files`)
```
prek run -a
```

@kiritorl (Contributor, Author)

My bad, I should have noticed that. Thanks for the tip!

@Tcc0403 merged commit b708f79 into linkedin:main on Jan 20, 2026
3 of 7 checks passed